from BasicLibraries import *
import Functions.Miscellaneous.utils as ut
import Make_universe
# NOTE(review): not referenced in this part of the file; presumably a
# base-path prefix for data files — confirm usage before relying on it.
ONEDRIVE: str = ""

class DataStructuring:
    """Helpers that reshape the firm-year panel and print descriptive tables.

    The table methods print LaTeX (via ``DataFrame.to_latex``) and also
    return the printed string.
    """

    @staticmethod
    def _escape_latex(text):
        """Escape ``$MM`` and ``%`` occurrences in generated LaTeX."""
        # Raw strings: the old '\$' / '\%' literals were invalid Python escape
        # sequences (SyntaxWarning since 3.12); the produced text is the same.
        return text.replace('$MM', r'\$MM').replace('%', r'\%')

    def make_empirical_matrix(self, dx, actions, sdgs_lst, output='binary'):
        """Aggregate ``'<action> - SDG <n>'`` columns into an SDG x action matrix.

        Parameters
        ----------
        dx : DataFrame
            Must hold one column per action/SDG pair, named
            ``'<action> - SDG <n>'``.
        actions : list of str
            Action labels; they become the columns of the result.
        sdgs_lst : iterable
            SDG identifiers; each is formatted as ``'SDG ' + str(n)``.
        output : str
            ``'binary'`` first flags, per column, the entries that are
            strictly positive and at or above that column's 75th percentile
            (1.0 if so, else 0.0). Any other value sums the raw values.

        Returns
        -------
        DataFrame
            Index: SDG identifiers as strings (name ``'sdg'``); columns: the
            action labels (name ``'action'``); cells: per-pair totals over
            all rows of ``dx``.
        """
        tmp_dt = dx.copy().reset_index()

        sdgs = ['SDG ' + str(s) for s in sdgs_lst]
        initiatives = [a + ' - ' + s for a in actions for s in sdgs]
        values = tmp_dt[initiatives]
        if output == 'binary':
            print('Binarisation')
            quantiles = values.quantile(0.75)
            # Flag directly: the previous mask -> clip(0, 1) -> round() turned
            # flagged values in (0, 0.5] into 0 instead of 1.
            values = ((values >= quantiles) & (values > 0)).astype(float)

        # Long format: one row per (initiative, observation). Columns are
        # renamed positionally so we do not depend on auto-generated names.
        long_ = values.unstack().reset_index()
        long_.columns = ['initiative', 'row', 'value']
        # Split on the exact ' - ' separator used above (from the right), so
        # action names containing '-' are parsed correctly and labels carry
        # no stray spaces.
        long_['action'] = long_['initiative'].apply(lambda x: x.rsplit(' - ', 1)[0])
        long_['sdg'] = long_['initiative'].apply(
            lambda x: x.rsplit(' - ', 1)[1].split('SDG ')[1])

        totals = long_.groupby(['action', 'sdg'])['value'].sum().reset_index()
        # Keyword arguments: DataFrame.pivot is keyword-only in pandas >= 2.0.
        return totals.pivot(index='sdg', columns='action', values='value')

    def get_the_data(self, population_type):
        """Load the panel CSV at path ``population_type`` and add derived columns.

        Renames ``GICS_level_1_x`` -> ``GICS_level_1`` and ``DirectControl`` ->
        ``DirectControl_intensity``, casts ``rfyear`` to str, and adds
        ``Tangibility`` (ppent_usd / at_usd), ``buyback_rate``
        (iss_eq_usd_ref / 1e6) and ``Investment_log``
        (log of Total_invested_capital_BS1; non-positive inputs give -inf/NaN).
        """
        dt = pd.read_csv(population_type, low_memory=False)
        dt = dt.rename(columns={'GICS_level_1_x': 'GICS_level_1'})
        dt['rfyear'] = dt['rfyear'].astype(str)
        dt['Tangibility'] = dt['ppent_usd'] / dt['at_usd']
        # NOTE(review): named a "rate" but computed as an amount rescaled by
        # 1e6 — confirm the intended units.
        dt['buyback_rate'] = dt['iss_eq_usd_ref'] / 1000000
        dt['Investment_log'] = dt['Total_invested_capital_BS1'].apply(np.log)
        dt = dt.rename(columns={'DirectControl': 'DirectControl_intensity'})
        return dt

    def make_quartile_statistics(self, dt, vars_, vars_names):
        """Print (and return) a LaTeX table of cross-firm quartiles.

        Each firm (``gvkey``) is reduced to its median over its rows; the table
        reports the 25th/50th/75th percentiles of those firm medians, with the
        mean in parentheses next to the median. ``DirectControl_intensity`` is
        log-transformed before aggregation.

        Parameters
        ----------
        dt : DataFrame
            Needs ``gvkey``, ``DirectControl_intensity`` and every ``vars_``.
        vars_ : list of str
            Columns to summarise.
        vars_names : list of str
            Display names, one per entry of ``vars_``.

        Returns
        -------
        str
            The LaTeX that was printed.
        """
        loc_dt = dt.copy()
        loc_dt[['DirectControl_intensity']] = (
            loc_dt[['DirectControl_intensity']].apply(np.log).round(2))
        # A single describe() over the firm medians feeds quartiles and mean.
        summary = (loc_dt[['gvkey'] + vars_].groupby('gvkey').median()
                   .describe().round(2))
        summary.columns = vars_names
        summary = summary.transpose()

        table_ = summary[['25%', '50%', '75%']].copy()
        table_['50% (mean)'] = (table_['50%'].astype(str) + '('
                                + summary['mean'].astype(str) + ')')
        table_ = table_[['25%', '50% (mean)', '75%']]

        latex = self._escape_latex(table_.to_latex(escape=False))
        print(latex)
        return latex

    def make_main_table(self, dt, vars_, vars_names):
        """Print (and return) a LaTeX table of yearly means plus an overall row.

        Rows: one per ``rfyear`` with the means of ``vars_`` and the number of
        non-null ``gvkey`` rows in that year, then an ``Average`` row with the
        overall means and the number of distinct ``gvkey``. ``vars_names``
        must name ``vars_`` followed by the count column, which must be called
        ``'Firms'``. The ``Average`` row is preceded by an hline rule.

        Returns
        -------
        str
            The LaTeX that was printed.
        """
        yearly = dt[vars_ + ['rfyear']].groupby('rfyear').mean().round(2)
        counts = dt[['gvkey', 'rfyear']].groupby('rfyear').count().astype(int)
        overall = dt[vars_].mean()
        overall['gvkey'] = len(dt.gvkey.unique())

        table_ = pd.concat((yearly, counts), axis=1)
        table_ = pd.concat(
            (table_, pd.DataFrame(overall, columns=['Average']).transpose())).round(2)
        # NOTE: here the log is taken after averaging, whereas
        # make_quartile_statistics logs before aggregating.
        table_[['DirectControl_intensity']] = (
            table_[['DirectControl_intensity']].apply(np.log).round(2))
        table_.columns = vars_names
        table_['Firms'] = table_['Firms'].astype(int)
        # rfyear may be text such as '2015.0' (get_the_data casts it to str);
        # show plain integer years.
        table_.index = [int(float(year)) for year in table_.index[:-1]] + ['Average']

        latex = table_.to_latex(escape=False).replace('Average', r'\hline Average')
        latex = self._escape_latex(latex)
        print(latex)
        return latex

    def table_count_by_dim(self, dx, cmp=None, dim='GICS_level_1'):
        """Count distinct firms (``gvkey``) per year and per category of ``dim``.

        Parameters
        ----------
        dx : DataFrame
            Needs ``rfyear``, ``gvkey`` and (unless supplied by ``cmp``) ``dim``.
        cmp : DataFrame, optional
            Extra per-``gvkey`` attributes merged in on ``gvkey``; its columns
            replace same-named columns of ``dx``. ``None`` or an empty object
            (e.g. ``[]``, the previous default) skips the merge.
        dim : str
            Column whose categories form the table's columns.

        Returns
        -------
        DataFrame of int
            Rows: years, then ``'Total'``; columns: categories, then
            ``'Total'``. Year x category cells count distinct firms (0 when
            none); the ``Total`` row counts distinct firms per category over
            all years; the ``Total`` column counts distinct firms per year;
            the corner is the sum of the ``Total`` row. Missing ``dim`` values
            are counted under the category ``0``.
        """
        L = dx.copy()
        if cmp is not None and len(cmp) > 0:
            # Drop the columns cmp re-supplies so the merge creates no _x/_y
            # suffixes (previously GICS_level_1 was dropped regardless of dim).
            overlap = [c for c in cmp.columns if c != 'gvkey' and c in L.columns]
            L = L.drop(columns=overlap).merge(cmp, on='gvkey')
        L['rfyear'] = L['rfyear'].astype(float).astype(int)
        # Missing categories go under 0 everywhere (previously only in the
        # Total row, so the per-year cells and the Total row disagreed).
        L[dim] = L[dim].replace(np.nan, 0)

        per_year = (L[[dim, 'rfyear', 'gvkey']].groupby([dim, 'rfyear']).nunique()
                    .reset_index()
                    .pivot(index='rfyear', columns=dim, values='gvkey'))
        totals_row = L[[dim, 'gvkey']].groupby(dim).nunique().transpose()
        totals_row.index = ['Total']
        table_ = pd.concat((per_year, totals_row))

        totals_col = L[['rfyear', 'gvkey']].groupby('rfyear').nunique().astype(int)
        totals_col.columns = ['Total']
        table_ = pd.concat((table_, totals_col), axis=1)

        # Only the corner gets the grand total; every other gap is a
        # year/category pair with no firms, i.e. 0 (previously those were
        # also filled with the grand total).
        table_.loc['Total', 'Total'] = totals_row.sum().sum()
        return table_.fillna(0).astype(int)

    def SampleStat_of_convergent_firms(self, dx, vars_, vars_names,
                                       process_type='1_0_opt_GA_process_data_0724_2021.pckl'):
        """Quartile statistics for rows whose ``mrg`` appears in the results
        with ``OptimalPerformance`` > ``ExpectedPerformance``.

        ``process_type`` is a pickle holding an 11-element sequence whose 6th
        element is the results table (with ``mrg``, ``OptimalPerformance`` and
        ``ExpectedPerformance``). Prints the number of distinct ``gvkey`` in
        the subsample, then the table from ``make_quartile_statistics``.

        Returns
        -------
        str
            The LaTeX printed by ``make_quartile_statistics``.
        """
        # SECURITY: pickle.load can execute arbitrary code — only load
        # trusted, locally produced files (presumably the GA optimisation
        # output, per the default filename — confirm).
        with open(process_type, 'rb') as fh:
            _, _, _, _, _, res, _, _, _, _, _ = pickle.load(fh)
        res = res[res['OptimalPerformance'] > res['ExpectedPerformance']]

        ## Note: this keeps the company-year observations (rows), not one row per company.
        subsample = dx[dx.mrg.isin(res.mrg.unique())]
        print(len(subsample.gvkey.unique()))
        return self.make_quartile_statistics(subsample, vars_, vars_names)

